Conversation
…y will come back to this
…pe on key for train test split
…other density estimators not just MAFs as in version 1
There was a problem hiding this comment.
Pull request overview
This pull request represents a major version update (1.4.2 → 2.0.0) that migrates margarine from TensorFlow to JAX. The refactoring introduces significant architectural improvements and adds new density estimators.
Key changes:
- Complete migration from TensorFlow to JAX/Flax for improved performance and GPU acceleration
- Added NICE and RealNVP normalizing flow implementations
- Introduced a
BaseDensityEstimatorabstract class providing a common API across all estimators - Restructured codebase with modular organization (base/, estimators/, utils/)
- Rewrote clustering implementation to support any density estimator type (Piecewise Normalizing Flows)
- Added JAX-based K-means clustering implementation
- Updated documentation to use MkDocs instead of Sphinx
Reviewed changes
Copilot reviewed 40 out of 52 changed files in this pull request and generated 11 comments.
Show a summary per file
| File | Description |
|---|---|
| tests/test_utils.py | New tests for utility functions (transformations, bounds) |
| tests/test_importance_sampling.py | Rewritten importance sampling tests using JAX and RealNVP |
| tests/test_estimators.py | New comprehensive tests for NICE, RealNVP, and KDE estimators |
| tests/test_cluster.py | New tests for clustered/piecewise normalizing flows |
| margarine/statistics.py | New module for KL divergence, model dimensionality, and integration |
| margarine/utils/utils.py | JAX implementations of transformation and bounds estimation utilities |
| margarine/utils/kmeans.py | JAX-based K-means clustering implementation |
| margarine/base/baseflow.py | Abstract base class defining common density estimator interface |
| margarine/estimators/realnvp.py | RealNVP normalizing flow implementation in JAX/Flax |
| margarine/estimators/nice.py | NICE normalizing flow implementation in JAX/Flax |
| margarine/estimators/kde.py | KDE implementation with JAX support and conditional sampling |
| margarine/estimators/clustered.py | Piecewise NF wrapper supporting any base estimator |
| pyproject.toml | Updated dependencies from TensorFlow to JAX/Flax/Optax |
| mkdocs.yaml | New MkDocs configuration replacing Sphinx |
| docs/tutorials.md | New comprehensive tutorials for v2.0.0 |
| README.md | Rewritten README with updated examples and information |
| margarine/_version.py | Version bump to 2.0.0 |
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
This was referenced Jan 9, 2026
Closed
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Description
This PR is designed to update
margarineto use JAX and include several different normalising flows.Impact on issues
Fixes #67 and #29.
#8 is no longer relevant.
I'm leaving error calculation on marginal statistics up to the user, closing #28. Although I will add a discussion to the documentation.
Should also fix #56
Key changes
BaseDensityEstimatorthat all density estimators inherit from. This means that each density estimator has an expected set of methods.clusterMAFclass has been rewritten into theclusterclass inmargarine/estimators/clustered.py. It takes advantage of the common API for each density estimator to allow users to build Piecewise NFs with any other implemented NF architecture (e.g. users can now build RealNVP PNFs, NICE PNFs, and even Piecewise KDEs).margarine/base/, density estimators are kept inmargarine/estimators/and utilities are kept inmargarine/utils/.jax.scipy.statsdoesn't have one and it is needed for Piecewise Normalising Flows.__call__function for KDE. To transform samples from the unit hypercube on to the KDE you need conditional inverse transform sampling and this needs to be reimplemented in JAX.Checklist:
python -m pytest)